# -*- coding: utf-8 -*-
"""Supplementary_LCB_Bonus.ipynb

Automatically generated by Colaboratory.
"""

'''
Simulate a toy MDP, an offline dataset, and a policy to evaluate, in order to
discover examples where the LCB penalty turns into an optimistic bonus.
'''
import numpy as np
from scipy import linalg


# Sampler for the synthetic data: i.i.d. standard-normal draws of a given shape.
dist = np.random.standard_normal

# Dimensions of the state and action feature vectors.
d_s = 30
d_a = 30

# Every sampled state/action vector is divided by this factor.
scale = 1.0

# Size of the offline dataset: num_episodes * steps_per_episode transitions.
num_episodes = 5
steps_per_episode = 5

# Factor multiplying the operator C in the series and the singular-value check.
gamma = 0.5

# Added to the diagonal of X X^T before it is inverted.
EPS = 1e-8


def check_mdp_result(seed):
  """Sample one toy offline dataset and test whether its LCB penalty goes negative.

  Reseeds NumPy's global RNG with `seed`, draws `num_episodes` trajectories of
  `steps_per_episode` transitions each, and stacks the concatenated
  (state, action) features into `X` and the (next state, same action) features
  into `Xprime`. It then forms `C = Xprime X^T (X X^T + EPS * I)^{-1}`,
  accumulates the truncated series `sum_{k=0}^{1000} (gamma * C)^k`, and
  applies it to the per-sample standard deviations to get the LCB penalty.

  Args:
    seed: Seed passed to `np.random.seed`.

  Returns:
    A pair `(bad_lcb, good_sv)` of boolean flags:
      bad_lcb: True if some entry of the LCB penalty is negative, i.e. the
        "penalty" is actually a bonus for that (s', pi(s')).
      good_sv: True if `gamma` times the largest singular value of `C` is
        below 1.

    NOTE: the series is accumulated regardless of `good_sv`. When `good_sv` is
    False the partial sum is not guaranteed to converge, so `bad_lcb` should
    only be relied upon together with `good_sv`.
  """
  np.random.seed(seed)

  # Create the data matrices. The order of the random draws (episode start
  # state, then action and next state per step) fixes which dataset a given
  # seed produces, so it must not be rearranged.
  X = []
  Xprime = []
  for _ in range(num_episodes):
    s = dist(d_s) / scale
    for _ in range(steps_per_episode):
      a = dist(d_a) / scale
      sprime = dist(d_s) / scale
      X.append(np.concatenate([s, a]))
      # The next-step feature pairs the next state with the same action.
      Xprime.append(np.concatenate([sprime, a]))
      s = sprime
  X = np.array(X)
  Xprime = np.array(Xprime)
  identity = np.eye(X.shape[0])

  # EPS on the diagonal avoids numerical issues with the matrix inversion.
  A = X @ X.T + EPS * identity
  C = Xprime @ X.T @ np.linalg.inv(A)

  # Largest singular value of gamma * C (svd returns them in descending order).
  singular_value = np.linalg.svd(C)[1][0] * gamma

  # Compute stddev for LCB, i.e., sqrt(E[(Q0(X') - C * Q0(X))^2]).
  # We assume Q-functions are linear function approximators.
  # We assume the initial weight distribution is a spherical Gaussian with
  # dimension d_s + d_a. Back-of-the-envelope calculations using this
  # assumption and our derived equations lead to the row norms below.
  Xdiff = Xprime - C @ X
  Xdiff_std = np.sqrt(np.sum(Xdiff ** 2, axis=-1))

  # Accumulate I + gamma*C + (gamma*C)^2 + ... one term per iteration.
  # gamma * C and the identity do not change, so they are built once.
  num_iterations = 1000
  discounted_C = gamma * C
  cum_C = identity
  for _ in range(num_iterations):
    cum_C = identity + discounted_C @ cum_C

  lcb_penalty = cum_C @ Xdiff_std

  # If the min is < 0, there exists a (s', \pi(s')) where the LCB penalty is
  # actually a bonus.
  min_lcb = np.min(lcb_penalty)

  return (min_lcb < 0.), (singular_value < 1.)

# Sweep seeds 0..999 and collect the (bad_lcb, good_sv) flags of each dataset.
outcomes = [check_mdp_result(seed) for seed in range(1000)]

# Tally each flag on its own ...
bad_lcbs = sum(int(bad_lcb) for bad_lcb, _ in outcomes)
num_good_singular_values = sum(int(good_sv) for _, good_sv in outcomes)
# ... and the seeds where both flags hold at once: a negative LCB penalty
# together with a scaled largest singular value below 1.
num_good_examples = sum(
    int(bad_lcb and good_sv) for bad_lcb, good_sv in outcomes
)

print(f'Number of examples found: {num_good_examples}')
